package com.code2roc.fastboot.framework.socket;

import com.alibaba.fastjson.JSONObject;
import com.code2roc.fastboot.framework.auth.TokenModel;
import com.code2roc.fastboot.framework.auth.TokenUtil;
import com.code2roc.fastboot.framework.model.Result;
import com.code2roc.fastboot.framework.util.BeanUtil;
import com.code2roc.fastboot.framework.util.LogUtil;
import com.code2roc.fastboot.framework.util.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @ServerEndpoint 注解是一个类层次的注解，它的功能主要是将目前的类定义成一个websocket服务器端,
 * 注解的值将被用于监听用户连接的终端访问URL地址,客户端可以通过这个URL来连接到WebSocket服务器端
 * 即 @ServerEndpoint 可以把当前类变成websocket服务类
 */
//访问服务端的url地址
@Component
@ServerEndpoint("/frame/socket/{token}")
public class BaseWebSocketServer {
    private static Logger log = LoggerFactory.getLogger(BaseWebSocketServer.class);
    @Autowired
    private static TokenUtil tokenUtil = BeanUtil.getBean(TokenUtil.class);
    private static ConcurrentHashMap<String, BaseWebSocketServer> webSocketMap = new ConcurrentHashMap<>();
    private String token = "";
    private Session session;
    private String socketIDKey = "";

    @OnOpen
    public void onOpen(Session session, @PathParam("token") String token) {
        this.token = token;
        this.session = session;
        socketIDKey = token + "|" + session.getId() + "|" + tokenUtil.getTokenModel(token).getUserID();
        try {
            if (tokenUtil.checkTokenValid(token)) {
                if (webSocketMap.containsKey(socketIDKey)) {
                    webSocketMap.remove(socketIDKey);
                    webSocketMap.put(socketIDKey, this);
                } else {
                    webSocketMap.put(socketIDKey, this);
                }
                sendMessage(socketIDKey, Result.okResult().setMsg("连接成功"));
                LogUtil.writeLog("websocket/connect", "sqllog", "用户连接成功:【用户名】：" + tokenUtil.getTokenModel(token).getUserName()
                        + "【token】：" + token + "【sessionid】:" + session.getId() + "【当前连接总数】：" + webSocketMap.size());
            } else {
                sendMessage(socketIDKey, Result.errorResult().setMsg("token无效"));
            }
        } catch (Exception e) {
            LogUtil.writeLog("websocket/connect", "sqllog", "用户网络异常:" + tokenUtil.getTokenModel(token).getUserName());
        }
    }

    @OnClose
    public void onClose() {
        if (webSocketMap.containsKey(socketIDKey)) {
            webSocketMap.remove(socketIDKey);
        }
        LogUtil.writeLog("websocket/connect", "socketconnect", "用户退出:" + tokenUtil.getTokenModel(token).getUserName());
    }

    @OnError
    public void onError(Session session, Throwable error) {
        LogUtil.writeLog("websocket/connect", "socketconnect", "用户错误:" + tokenUtil.getTokenModel(token).getUserName() + ",原因:" + error.getMessage());
        error.printStackTrace();
    }

    @OnMessage
    public void onMessage(String message, Session session) {
        LogUtil.writeLog("websocket/connect", "socketconnect", "用户消息:" + tokenUtil.getTokenModel(token).getUserName() + ",报文:" + message);
        if (!StringUtil.isEmpty(message)) {
            if (message.equals("heartbeat")) {
                if (tokenUtil.checkTokenValid(token)) {
                    sendMessage(socketIDKey, Result.okResult().setMsg("heartbeat response"));
                    log.debug("接收用户【" + tokenUtil.getTokenModel(token).getUserName() + "】心跳消息");
                } else {
                    if (webSocketMap.containsKey(socketIDKey)) {
                        webSocketMap.remove(socketIDKey);
                    }
                }
            }
        }
    }

    private static void sendMessage(String socketIDKey, Result result) {
        try {
            if (webSocketMap.containsKey(socketIDKey)) {
                webSocketMap.get(socketIDKey).session.getBasicRemote().sendText(JSONObject.toJSONString(result));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public static void sendMessageToUser(String userID, Result result) {
        try {
            for (String socketIDKey : webSocketMap.keySet()) {
                if (socketIDKey.endsWith(userID)) {
                    String token = socketIDKey.split("\\|")[0];
                    if (tokenUtil.checkTokenValid(token)) {
                        TokenModel tokenModel = tokenUtil.getTokenModel(token);
                        if (tokenModel.getUserID().equals(userID)) {
                            webSocketMap.get(socketIDKey).session.getBasicRemote().sendText(JSONObject.toJSONString(result));
                        }
                    } else {
                        if (webSocketMap.containsKey(socketIDKey)) {
                            webSocketMap.remove(socketIDKey);
                        }
                    }
                }
            }
        } catch (Exception e) {
            log.debug("捕获websocket异常");
            e.printStackTrace();
        }
    }

    public static boolean checkConnectExist(String token){
        boolean flag = false;
        for (String socketIDKey : webSocketMap.keySet()) {
            if (socketIDKey.startsWith(token)) {
                flag = true;
                break;
            }
        }
        return flag;
    }

    public static void deleteConnect(String token){
        for (String socketIDKey : webSocketMap.keySet()) {
            if (socketIDKey.startsWith(token)) {
                webSocketMap.remove(socketIDKey);
                break;
            }
        }
    }
}
